{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Solving small linear systems via SVD.\n",
    "case 0. Invertible matrix (demonstrates that SVD\n",
    "        is equivalento matrix inversion in the\n",
    "        simplest case)\n",
    "case 1. Non square matrix (over or under determined \n",
    "        system). Here matrix inversion does not work.\n",
    "        SVD finds the best approximation.\n",
    "\"\"\"\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run 4.3.2-common.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solution via inverse: tensor([1.0000, 2.0000, 3.0000])\n",
      "Solution via SVD: tensor([1.0000, 2.0000, 3.0000])\n"
     ]
    }
   ],
   "source": [
    "# Case 0: Invertible square matrix\n",
    "A = torch.tensor([[1, 2, 1], [2, 2, 3], [1, 3, 3]], dtype=torch.float)\n",
    "b = torch.tensor([8, 15, 16], dtype=torch.float)\n",
    "\n",
    "# Solve this via matrix inversion\n",
    "# i.e x = inv(A) *  b\n",
    "x_0 = torch.matmul(torch.linalg.inv(A), b)\n",
    "print(\"Solution via inverse: {}\".format(x_0))\n",
    "\n",
    "# Computation of inverse can be unstable. \n",
    "# Now solve this via SVD instead.\n",
    "U, S, V_t = torch.linalg.svd(A)\n",
    "\n",
    "# A*x = b ==> U*S*V_t*x = b\n",
    "#         ==> S * V_t*x = U_t*b \n",
    "# (inv(U) is U_t, U being orthogonal)\n",
    "S_V_t_x = torch.matmul(U.T, b)\n",
    "\n",
    "# V_t*x = inv(S) * U_t * b  \n",
    "# Inv(S) is easily computable as S is diagonal\n",
    "S_inv = torch.diag(1 / S) \n",
    "V_t_x = torch.matmul(S_inv, S_V_t_x)\n",
    "\n",
    "# x= V * inv(S) * U_t * b \n",
    "x_1 = torch.matmul(V_t.T, V_t_x)\n",
    "print(\"Solution via SVD: {}\".format(x_1))\n",
    "\n",
    "# assert that the two solutions are more\n",
    "# or less identical\n",
    "assert torch.allclose(x_0, x_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solution via pseudo-inverse:\n",
      "tensor([ 1.0766,  0.8976, -0.9582])\n",
      "Solution via SVD:\n",
      "tensor([ 1.0766,  0.8976, -0.9582])\n"
     ]
    }
   ],
   "source": [
    "# Case 1: Non square matrices\n",
    "# Let us revisit our cat brain dataset \n",
    "\n",
    "A = torch.tensor([[0.11, 0.09], [0.01, 0.02],\n",
    "              [0.98, 0.91], [0.12, 0.21],\n",
    "              [0.98, 0.99], [0.85, 0.87],\n",
    "              [0.03, 0.14], [0.55, 0.45],\n",
    "              [0.49, 0.51], [0.99, 0.01],\n",
    "              [0.02, 0.89], [0.31, 0.47],\n",
    "              [0.55, 0.29], [0.87, 0.76],\n",
    "              [0.63, 0.24]])\n",
    "A = torch.column_stack((A, torch.ones(15))) \n",
    "b = torch.tensor([-0.8, -0.97, 0.89, -0.67,\n",
    "              0.97, 0.72, -0.83, 0.00,\n",
    "              0.00, 0.00, -0.09, -0.22,\n",
    "              -0.16, 0.63, 0.37])\n",
    "\n",
    "# One way to solve for this is via pseudo-inverse\n",
    "# i.e x = pseudo_inv(A) *  b\n",
    "x_0 = torch.matmul(torch.linalg.pinv(A), b)\n",
    "print(\"Solution via pseudo-inverse:\\n{}\".\\\n",
    "      format(x_0))\n",
    "\n",
    "# Computation of inverse can be unstable. \n",
    "# So let us try solving this via SVD instead\n",
    "\n",
    "# Now solve this via SVD instead.\n",
    "U, S, V_t = torch.linalg.svd(A,\n",
    "                          full_matrices=False)\n",
    "# A*x = b ==> U*S*V_t*x = b\n",
    "#         ==> S * V_t*x = U_t*b \n",
    "# As inv(U) is U_t\n",
    "S_V_t_x = torch.matmul(U.T, b)\n",
    "\n",
    "# V_t*x = inv(S) * U_t * b  \n",
    "# Inv(S) is easily computable as it is diagonal\n",
    "S_inv = torch.diag(1 / S) \n",
    "V_t_x = torch.matmul(S_inv, S_V_t_x)\n",
    "\n",
    "# x= V * inv(S) * U_t * b \n",
    "x_1 = torch.matmul(V_t.T, V_t_x)\n",
    "print(\"Solution via SVD:\\n{}\".format(x_1))\n",
    "\n",
    "assert torch.allclose(x_0, x_1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
